{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "imagenet_nashlepka_layer4_editable_SGD_momentum_match_rank_2019.09.24_17:00:13\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%env CUDA_VISIBLE_DEVICES=SETYOURDEVICEHERE\n",
    "import os, sys, time\n",
    "sys.path.insert(0, '..')\n",
    "import lib\n",
    "\n",
    "import numpy as np\n",
    "import torch, torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import random\n",
    "random.seed(42)\n",
    "np.random.seed(42)\n",
    "torch.random.manual_seed(42)\n",
    "\n",
    "import time\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "experiment_name = 'imagenet_nashlepka_layer4_editable_SGD_momentum_match_rank'\n",
    "experiment_name = '{}_{}.{:0>2d}.{:0>2d}_{:0>2d}:{:0>2d}:{:0>2d}'.format(experiment_name, *time.gmtime()[:6])\n",
    "print(experiment_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.models as models\n",
    "\n",
    "data_path = '../../imagenet/'\n",
    "logits_path = 'imagenet_logits/'\n",
    "\n",
    "traindir = os.path.join(data_path, 'train')\n",
    "valdir = os.path.join(data_path, 'val')\n",
    "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                                 std=[0.229, 0.224, 0.225])\n",
    "\n",
    "train_dataset = lib.ImageAndLogitsFolder(\n",
    "    traindir,\n",
    "    transforms.Compose([\n",
    "        transforms.RandomResizedCrop(224),\n",
    "        transforms.RandomHorizontalFlip(),\n",
    "        transforms.ToTensor(),\n",
    "        normalize,\n",
    "    ]),\n",
    "    logits_prefix = logits_path\n",
    ")\n",
    "\n",
    "batch_size = 128\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=batch_size, shuffle=True,\n",
    "    num_workers=12, pin_memory=True)\n",
    "\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    datasets.ImageFolder(valdir, transforms.Compose([\n",
    "        transforms.Resize(256),\n",
    "        transforms.CenterCrop(224),\n",
    "        transforms.ToTensor(),\n",
    "        normalize,\n",
    "    ])),\n",
    "    batch_size=batch_size, shuffle=False,\n",
    "    num_workers=32, pin_memory=True)\n",
    "\n",
    "X_test, y_test = map(torch.cat, zip(*val_loader))\n",
    "X_test, y_test = X_test[::10], y_test[::10] \n",
    "# Note: we use 10% of data for early stopping\n",
    "# We evaluate on all data later"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "\n",
    "model = torchvision.models.resnet18(pretrained=True)\n",
    "\n",
    "optimizer = lib.IngraphRMSProp(learning_rate=1e-4, beta=nn.Parameter(torch.as_tensor(0.5)))\n",
    "\n",
    "model = lib.SequentialWithEditable(\n",
    "    model.conv1, model.bn1, model.relu, model.maxpool,\n",
    "    model.layer1, model.layer2, model.layer3, model.layer4,\n",
    "    model.avgpool, lib.Flatten(),\n",
    "    lib.Editable(\n",
    "        lib.Residual(nn.Linear(512, 4096), nn.ELU(), nn.Linear(4096, 512)),\n",
    "        loss_function=lib.contrastive_cross_entropy, \n",
    "        optimizer=optimizer, max_steps=10),\n",
    "\n",
    "    model.fc\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EditableMatchRankTrainer(lib.DistillationEditableTrainer):\n",
    "    def __init__(self, distribution_cumsum, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.cumsum = distribution_cumsum\n",
    "    \n",
    "    def train_on_batch(self, x_batch, y_batch, x_edit, y_edit, **kwargs):\n",
    "        with torch.no_grad(), lib.training_mode(self.model, is_train=False):\n",
    "            logits = self.model(x_edit)\n",
    "            sorted_ans = logits.topk(k=logits.shape[1], dim=1).indices\n",
    "            choices = (torch.rand(sorted_ans.shape[0]).view(-1, 1) > self.cumsum).to(torch.int32).sum(-1)\n",
    "            choices = choices.to(sorted_ans.device)\n",
    "            y_edit = sorted_ans.gather(dim=1, index=choices.view(-1, 1)).view(-1)\n",
    "        super().train_on_batch(x_batch, y_batch, x_edit, y_edit, **kwargs)\n",
    "        \n",
    "    def evaluate_metrics(self, X, y, X_edit=None, y_edit=None, size_top=25, **kwargs):\n",
    "        \"\"\"\n",
    "        For each sample in X_edit, y_edit attempts to train model and evaluates trained model quality\n",
    "        :param X: data for quality evaluaton\n",
    "        :param y: targets for quality evaluaton\n",
    "        :param X_edit: sequence of data for training model on\n",
    "        :param y_edit: sequence of targets for training model on\n",
    "        :param kwargs: extra parameters for error function\n",
    "        :return: dictionary of metrics\n",
    "        \"\"\"\n",
    "        assert (X_edit is None) == (y_edit is None), \"provide either both X_edit and y_edit or none of them\"\n",
    "        if X_edit is None:\n",
    "            num_classes = y.max() + 1\n",
    "            ind = np.random.permutation(len(X))[:10]\n",
    "            X_edit = X[ind]\n",
    "            with torch.no_grad(), lib.training_mode(self.model, is_train=False):\n",
    "                logits = self.model(X_edit)\n",
    "                sorted_ans = logits.topk(k=logits.shape[1], dim=1).indices\n",
    "                choices = (torch.rand(sorted_ans.shape[0]).view(-1, 1) > self.cumsum).to(torch.int32).sum(-1)\n",
    "                choices = choices.to(sorted_ans.device)\n",
    "                y_edit = sorted_ans.gather(dim=1, index=choices.view(-1, 1)).view(-1)\n",
    "        return super().evaluate_metrics(X, y, X_edit, y_edit, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def classification_error(model, X_test, y_test):\n",
    "    with lib.training_mode(model, is_train=False):\n",
    "        return lib.classification_error(lib.Lambda(lambda x: model(x.to(device))),\n",
    "                                        X_test, y_test, device='cpu', batch_size=128)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Read natural adversarial examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean = [0.485, 0.456, 0.406]\n",
    "std = [0.229, 0.224, 0.225]\n",
    "\n",
    "test_transform = transforms.Compose(\n",
    "    [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean, std)])\n",
    "\n",
    "thousand_k_to_200 = {0: -1, 1: -1, 2: -1, 3: -1, 4: -1, 5: -1, 6: 0, 7: -1, 8: -1, 9: -1, 10: -1, 11: 1, 12: -1, 13: 2, 14: -1, 15: 3, 16: -1, 17: 4, 18: -1, 19: -1, 20: -1, 21: -1, 22: 5, 23: 6, 24: -1, 25: -1, 26: -1, 27: 7, 28: -1, 29: -1, 30: 8, 31: -1, 32: -1, 33: -1, 34: -1, 35: -1, 36: -1, 37: 9, 38: -1, 39: 10, 40: -1, 41: -1, 42: 11, 43: -1, 44: -1, 45: -1, 46: -1, 47: 12, 48: -1, 49: -1, 50: 13, 51: -1, 52: -1, 53: -1, 54: -1, 55: -1, 56: -1, 57: 14, 58: -1, 59: -1, 60: -1, 61: -1, 62: -1, 63: -1, 64: -1, 65: -1, 66: -1, 67: -1, 68: -1, 69: -1, 70: 15, 71: 16, 72: -1, 73: -1, 74: -1, 75: -1, 76: 17, 77: -1, 78: -1, 79: 18, 80: -1, 81: -1, 82: -1, 83: -1, 84: -1, 85: -1, 86: -1, 87: -1, 88: -1, 89: 19, 90: 20, 91: -1, 92: -1, 93: -1, 94: 21, 95: -1, 96: 22, 97: 23, 98: -1, 99: 24, 100: -1, 101: -1, 102: -1, 103: -1, 104: -1, 105: 25, 106: -1, 107: 26, 108: 27, 109: -1, 110: 28, 111: -1, 112: -1, 113: 29, 114: -1, 115: -1, 116: -1, 117: -1, 118: -1, 119: -1, 120: -1, 121: -1, 122: -1, 123: -1, 124: 30, 125: 31, 126: -1, 127: -1, 128: -1, 129: -1, 130: 32, 131: -1, 132: 33, 133: -1, 134: -1, 135: -1, 136: -1, 137: -1, 138: -1, 139: -1, 140: -1, 141: -1, 142: -1, 143: 34, 144: 35, 145: -1, 146: -1, 147: -1, 148: -1, 149: -1, 150: 36, 151: 37, 152: -1, 153: -1, 154: -1, 155: -1, 156: -1, 157: -1, 158: -1, 159: -1, 160: -1, 161: -1, 162: -1, 163: -1, 164: -1, 165: -1, 166: -1, 167: -1, 168: -1, 169: -1, 170: -1, 171: -1, 172: -1, 173: -1, 174: -1, 175: -1, 176: -1, 177: -1, 178: -1, 179: -1, 180: -1, 181: -1, 182: -1, 183: -1, 184: -1, 185: -1, 186: -1, 187: -1, 188: -1, 189: -1, 190: -1, 191: -1, 192: -1, 193: -1, 194: -1, 195: -1, 196: -1, 197: -1, 198: -1, 199: -1, 200: -1, 201: -1, 202: -1, 203: -1, 204: -1, 205: -1, 206: -1, 207: 38, 208: -1, 209: -1, 210: -1, 211: -1, 212: -1, 213: -1, 214: -1, 215: -1, 216: -1, 217: -1, 218: -1, 219: -1, 220: -1, 221: -1, 222: -1, 223: -1, 224: -1, 225: -1, 226: -1, 227: -1, 228: -1, 229: -1, 230: -1, 231: -1, 232: -1, 233: -1, 234: 39, 235: 40, 236: -1, 237: -1, 238: -1, 239: -1, 240: -1, 241: -1, 242: -1, 243: -1, 244: -1, 245: -1, 246: -1, 247: -1, 248: -1, 249: -1, 250: -1, 251: -1, 252: -1, 253: -1, 254: 41, 255: -1, 256: -1, 257: -1, 258: -1, 259: -1, 260: -1, 261: -1, 262: -1, 263: -1, 264: -1, 265: -1, 266: -1, 267: -1, 268: -1, 269: -1, 270: -1, 271: -1, 272: -1, 273: -1, 274: -1, 275: -1, 276: -1, 277: 42, 278: -1, 279: -1, 280: -1, 281: -1, 282: -1, 283: 43, 284: -1, 285: -1, 286: -1, 287: 44, 288: -1, 289: -1, 290: -1, 291: 45, 292: -1, 293: -1, 294: -1, 295: 46, 296: -1, 297: -1, 298: 47, 299: -1, 300: -1, 301: 48, 302: -1, 303: -1, 304: -1, 305: -1, 306: 49, 307: 50, 308: 51, 309: 52, 310: 53, 311: 54, 312: -1, 313: 55, 314: 56, 315: 57, 316: -1, 317: 58, 318: -1, 319: 59, 320: -1, 321: -1, 322: -1, 323: 60, 324: 61, 325: -1, 326: 62, 327: 63, 328: -1, 329: -1, 330: 64, 331: -1, 332: -1, 333: -1, 334: 65, 335: 66, 336: 67, 337: -1, 338: -1, 339: -1, 340: -1, 341: -1, 342: -1, 343: -1, 344: -1, 345: -1, 346: -1, 347: 68, 348: -1, 349: -1, 350: -1, 351: -1, 352: -1, 353: -1, 354: -1, 355: -1, 356: -1, 357: -1, 358: -1, 359: -1, 360: -1, 361: 69, 362: -1, 363: 70, 364: -1, 365: -1, 366: -1, 367: -1, 368: -1, 369: -1, 370: -1, 371: -1, 372: 71, 373: -1, 374: -1, 375: -1, 376: -1, 377: -1, 378: 72, 379: -1, 380: -1, 381: -1, 382: -1, 383: -1, 384: -1, 385: -1, 386: 73, 387: -1, 388: -1, 389: -1, 390: -1, 391: -1, 392: -1, 393: -1, 394: -1, 395: -1, 396: -1, 397: 74, 398: -1, 399: -1, 400: 75, 401: 76, 402: 77, 403: -1, 404: 78, 405: -1, 406: -1, 407: 79, 408: -1, 409: -1, 410: -1, 411: 80, 412: -1, 413: -1, 414: -1, 415: -1, 416: 81, 417: 82, 418: -1, 419: -1, 420: 83, 421: -1, 422: -1, 423: -1, 424: -1, 425: 84, 426: -1, 427: -1, 428: 85, 429: -1, 430: 86, 431: -1, 432: -1, 433: -1, 434: -1, 435: -1, 436: -1, 437: 87, 438: 88, 439: -1, 440: -1, 441: -1, 442: -1, 443: -1, 444: -1, 445: 89, 446: -1, 447: -1, 448: -1, 449: -1, 450: -1, 451: -1, 452: -1, 453: -1, 454: -1, 455: -1, 456: 90, 457: 91, 458: -1, 459: -1, 460: -1, 461: 92, 462: 93, 463: -1, 464: -1, 465: -1, 466: -1, 467: -1, 468: -1, 469: -1, 470: 94, 471: -1, 472: 95, 473: -1, 474: -1, 475: -1, 476: -1, 477: -1, 478: -1, 479: -1, 480: -1, 481: -1, 482: -1, 483: 96, 484: -1, 485: -1, 486: 97, 487: -1, 488: 98, 489: -1, 490: -1, 491: -1, 492: 99, 493: -1, 494: -1, 495: -1, 496: 100, 497: -1, 498: -1, 499: -1, 500: -1, 501: -1, 502: -1, 503: -1, 504: -1, 505: -1, 506: -1, 507: -1, 508: -1, 509: -1, 510: -1, 511: -1, 512: -1, 513: -1, 514: 101, 515: -1, 516: 102, 517: -1, 518: -1, 519: -1, 520: -1, 521: -1, 522: -1, 523: -1, 524: -1, 525: -1, 526: -1, 527: -1, 528: 103, 529: -1, 530: 104, 531: -1, 532: -1, 533: -1, 534: -1, 535: -1, 536: -1, 537: -1, 538: -1, 539: 105, 540: -1, 541: -1, 542: 106, 543: 107, 544: -1, 545: -1, 546: -1, 547: -1, 548: -1, 549: 108, 550: -1, 551: -1, 552: 109, 553: -1, 554: -1, 555: -1, 556: -1, 557: 110, 558: -1, 559: -1, 560: -1, 561: 111, 562: 112, 563: -1, 564: -1, 565: -1, 566: -1, 567: -1, 568: -1, 569: 113, 570: -1, 571: -1, 572: 114, 573: 115, 574: -1, 575: 116, 576: -1, 577: -1, 578: -1, 579: 117, 580: -1, 581: -1, 582: -1, 583: -1, 584: -1, 585: -1, 586: -1, 587: -1, 588: -1, 589: 118, 590: -1, 591: -1, 592: -1, 593: -1, 594: -1, 595: -1, 596: -1, 597: -1, 598: -1, 599: -1, 600: -1, 601: -1, 602: -1, 603: -1, 604: -1, 605: -1, 606: 119, 607: 120, 608: -1, 609: 121, 610: -1, 611: -1, 612: -1, 613: -1, 614: 122, 615: -1, 616: -1, 617: -1, 618: -1, 619: -1, 620: -1, 621: -1, 622: -1, 623: -1, 624: -1, 625: -1, 626: 123, 627: 124, 628: -1, 629: -1, 630: -1, 631: -1, 632: -1, 633: -1, 634: -1, 635: -1, 636: -1, 637: -1, 638: -1, 639: -1, 640: 125, 641: 126, 642: 127, 643: 128, 644: -1, 645: -1, 646: -1, 647: -1, 648: -1, 649: -1, 650: -1, 651: -1, 652: -1, 653: -1, 654: -1, 655: -1, 656: -1, 657: -1, 658: 129, 659: -1, 660: -1, 661: -1, 662: -1, 663: -1, 664: -1, 665: -1, 666: -1, 667: -1, 668: 130, 669: -1, 670: -1, 671: -1, 672: -1, 673: -1, 674: -1, 675: -1, 676: -1, 677: 131, 678: -1, 679: -1, 680: -1, 681: -1, 682: 132, 683: -1, 684: 133, 685: -1, 686: -1, 687: 134, 688: -1, 689: -1, 690: -1, 691: -1, 692: -1, 693: -1, 694: -1, 695: -1, 696: -1, 697: -1, 698: -1, 699: -1, 700: -1, 701: 135, 702: -1, 703: -1, 704: 136, 705: -1, 706: -1, 707: -1, 708: -1, 709: -1, 710: -1, 711: -1, 712: -1, 713: -1, 714: -1, 715: -1, 716: -1, 717: -1, 718: -1, 719: 137, 720: -1, 721: -1, 722: -1, 723: -1, 724: -1, 725: -1, 726: -1, 727: -1, 728: -1, 729: -1, 730: -1, 731: -1, 732: -1, 733: -1, 734: -1, 735: -1, 736: 138, 737: -1, 738: -1, 739: -1, 740: -1, 741: -1, 742: -1, 743: -1, 744: -1, 745: -1, 746: 139, 747: -1, 748: -1, 749: 140, 750: -1, 751: -1, 752: 141, 753: -1, 754: -1, 755: -1, 756: -1, 757: -1, 758: 142, 759: -1, 760: -1, 761: -1, 762: -1, 763: 143, 764: -1, 765: 144, 766: -1, 767: -1, 768: 145, 769: -1, 770: -1, 771: -1, 772: -1, 773: 146, 774: 147, 775: -1, 776: 148, 777: -1, 778: -1, 779: 149, 780: 150, 781: -1, 782: -1, 783: -1, 784: -1, 785: -1, 786: 151, 787: -1, 788: -1, 789: -1, 790: -1, 791: -1, 792: 152, 793: -1, 794: -1, 795: -1, 796: -1, 797: 153, 798: -1, 799: -1, 800: -1, 801: -1, 802: 154, 803: 155, 804: 156, 805: -1, 806: -1, 807: -1, 808: -1, 809: -1, 810: -1, 811: -1, 812: -1, 813: 157, 814: -1, 815: 158, 816: -1, 817: -1, 818: -1, 819: -1, 820: 159, 821: -1, 822: -1, 823: 160, 824: -1, 825: -1, 826: -1, 827: -1, 828: -1, 829: -1, 830: -1, 831: 161, 832: -1, 833: 162, 834: -1, 835: 163, 836: -1, 837: -1, 838: -1, 839: 164, 840: -1, 841: -1, 842: -1, 843: -1, 844: -1, 845: 165, 846: -1, 847: 166, 848: -1, 849: -1, 850: 167, 851: -1, 852: -1, 853: -1, 854: -1, 855: -1, 856: -1, 857: -1, 858: -1, 859: 168, 860: -1, 861: -1, 862: 169, 863: -1, 864: -1, 865: -1, 866: -1, 867: -1, 868: -1, 869: -1, 870: 170, 871: -1, 872: -1, 873: -1, 874: -1, 875: -1, 876: -1, 877: -1, 878: -1, 879: 171, 880: 172, 881: -1, 882: -1, 883: -1, 884: -1, 885: -1, 886: -1, 887: -1, 888: 173, 889: -1, 890: 174, 891: -1, 892: -1, 893: -1, 894: -1, 895: -1, 896: -1, 897: 175, 898: -1, 899: -1, 900: 176, 901: -1, 902: -1, 903: -1, 904: -1, 905: -1, 906: -1, 907: 177, 908: -1, 909: -1, 910: -1, 911: -1, 912: -1, 913: 178, 914: -1, 915: -1, 916: -1, 917: -1, 918: -1, 919: -1, 920: -1, 921: -1, 922: -1, 923: -1, 924: 179, 925: -1, 926: -1, 927: -1, 928: -1, 929: -1, 930: -1, 931: -1, 932: 180, 933: 181, 934: 182, 935: -1, 936: -1, 937: 183, 938: -1, 939: -1, 940: -1, 941: -1, 942: -1, 943: 184, 944: -1, 945: 185, 946: -1, 947: 186, 948: -1, 949: -1, 950: -1, 951: 187, 952: -1, 953: -1, 954: 188, 955: -1, 956: 189, 957: 190, 958: -1, 959: 191, 960: -1, 961: -1, 962: -1, 963: -1, 964: -1, 965: -1, 966: -1, 967: -1, 968: -1, 969: -1, 970: -1, 971: 192, 972: 193, 973: -1, 974: -1, 975: -1, 976: -1, 977: -1, 978: -1, 979: -1, 980: 194, 981: 195, 982: -1, 983: -1, 984: 196, 985: -1, 986: 197, 987: 198, 988: 199, 989: -1, 990: -1, 991: -1, 992: -1, 993: -1, 994: -1, 995: -1, 996: -1, 997: -1, 998: -1, 999: -1}\n",
    "two_hundred_to_1000 = dict(map(reversed, thousand_k_to_200.items()))\n",
    "\n",
    "X_adv, y_adv = zip(*datasets.ImageFolder(root=\"./imagenet-a/\", transform=test_transform))\n",
    "X_adv = torch.stack(X_adv)\n",
    "y_adv = torch.tensor(list(map(two_hundred_to_1000.get, y_adv)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAD8CAYAAAB3u9PLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3Xl4VOXZ+PHvPTPJZE8IhH0JSJBNQUVEXOsKWsW28hbbWhdau2i1br8Xq7WtrVatLdq+amvrvhRcqFJFUURc2cImayDsIRAChOx7nt8f58xkZjJJhpD93J/r4uLMOc8585wZOPc8uxhjUEoppVwdnQGllFKdgwYEpZRSgAYEpZRSNg0ISimlAA0ISimlbBoQlFJKARoQlFJK2TQgKKWUAjQgKKWUsnk6OgPHolevXiY9Pb2js6GUUl3GqlWrDhlj0iJJ26UCQnp6OpmZmR2dDaWU6jJEZHekabXKSCmlFKABQSmllE0DglJKKUADglJKKZsGBKWUUoAGBKWUUjYNCEoppQCHBYSVu46QdaC4o7OhlFKdUpcamHa8pv99KQC7Hr68g3OilFKdj6NKCEoppRqnAUEppRSgAUEppZRNA4JSSikgwoAgIlNEJEtEskVkVpjjXhGZax9fLiLp9v6eIvKJiJSIyP+FnHOaiKy3z/mriEhr3JBSSqmWaTYgiIgbeBKYCowGrhGR0SHJZgIFxpjhwGzgEXt/BfBr4K4wl34auAnIsP9MackNRMoY05aXV0qpLi+SEsJEINsYs8MYUwXMAaaFpJkGvGhvvwlcKCJijCk1xnyBFRj8RKQfkGSMWWqsJ/VLwFXHcyPNqanTgKCUUk2JJCAMAPYGvM6x94VNY4ypAQqBns1cM6eZa7aq6to6/7aWFpRSqqFIAkK4uv3QJ2okaVqUXkRuEpFMEcnMz89v4pJN+/unO/zb1bUaEJRSKlQkASEHGBTweiCQ21gaEfEAycCRZq45sJlrAmCMecYYM8EYMyEtLaJlQRsoLKvmrx9v878ur65t0XWUUqo7iyQgrAQyRGSoiEQDM4D5IWnmA9fZ21cDi00T9TLGmP1AsYhMsnsX/RB455hzHyETUvio1ICglFINNDuXkTGmRkRuARYCbuA5Y8xGEXkAyDTGzAeeBV4WkWysksEM3/kisgtIAqJF5CrgEmPMJuBnwAtALPC+/adNhDYoawlBKaUaimhyO2PMAmBByL77A7YrgOmNnJveyP5MYGykGT0etSEBoaK6rpGUSinlXI4YqawlBKWUap4zAkJtcImgvEoDglJKhXJGQAitMqrRgKCUUqEcERBC2xBqdByCUko14IiA4AsAo/olAVBbp43KSikVyhkBwQ4A08b3B3SkslJKheOQgGAFgBiPdbuhVUhKKaUcEhB8AcAb5QaCJ7pTSillcURA8AWAmCjrdnUqbKWUasgRAaHWX2VklRA0ICilVEOOCAj+NgS7yih0oJpSSimHBITaWl8bgjYqK6VUYxwREHzdTmP8jcoaEJRSKpRDAkJIG4JWGSmlVAOOCAi+KqJoj/YyUkqpxjgiINTZi7e5XYLbJf4qJKWUUvUcERB8i3m6BDwu0RKCUkqF4YiAUOcPCGIFBG1UVkqpBhwSEOoDgMft0kZlpZQKwxEBAV8JwSV4PS5dQlMppcJwREDwlRBcAn2TY9hfWNHBOVJKqc7HIQHB+tslwoCUWPYdLe/YDCmlVCfkkIBgRQQBUuKiKK6o6dgMKaVUJ+SIgOBrUhYRPC5tVFZKqXCcERAC2hA8bh2HoJRS4TgiINTZAUACxiF8tCmP4orqDs6ZUkp1Ho4ICL7ygFVCsLqd/vilTO54fV2H5ksppToTRwQEXw2Rr4Tgs/NQaQflSCmlOh9HBARfG4IIeFyuBvuVUkpFGBBEZIqIZIlItojMCnPcKyJz7ePLRSQ94Ng99v4sEbk0YP/tIrJRRDaIyL9FJKY1biic+oFpgsctzaRWSilnajYgiIgbeBKYCowGrhGR0SHJZgIFxpjhwGzgEfvc0cAMYAwwBXhKRNwiMgC4FZhgjBkLuO10bSJ0tlOllFINRVJCmAhkG2N2GGOqgDnAtJA004AX7e03gQtFROz9c4wxlcaYnUC2fT0ADxArIh4gDsg9vltpXNBsp25H1JIppdQxi+TpOADYG/A6x94XNo0xpgYoBHo2dq4xZh/wGLAH2A8UGmM+bMkNRCJwttPAAoK2ICilVL1IAkK4OpbQZ2ljacLuF5EeWKWHoUB/IF5EfhD2zUVuEpFMEcnMz8+PILuNc4lQraOUlVIqrEgCQg4wKOD1QBpW7/jT2FVAycCRJs69CNhpjMk3xlQD84DJ4d7cGPOMMWaCMWZCWlpaBNltyDcwzSVQVRMQELSIoJRSfpEEhJVAhogMFZForMbf+SFp5gPX2dtXA4uN1adzPjDD7oU0FMgAVmBVFU0SkTi7reFCYPPx3054geMQKgMCQoWui6CUUn6e5hIYY2pE5BZgIVZvoOeMMRtF5AEg0xgzH3gWeFlEsrFKBjPsczeKyOvAJqAGuNkYUwssF5E3gdX2/jXAM61/e/Y9EL6EUFqlAUEppXyaDQgAxpgFwIKQffcHbFcA0xs590HgwTD7fwP85lgy21KNlRAKy6vZeaiUob3i2yMbSinVqTmiD6YxhsaGH8xbndO+mVFKqU7KEQGhzhispgr45UUZnJPRizd+eiZej4ucAl09TSmlIMIqo67OmPrxBylx0bw88wwAxvRPIvtgCfsLy+mXHNuBOVRKqY7nkBIC/hJCoF4JXtbvK2TGM8s6IFdKKdW5OCIgGGPCjpDz2X24rN3yopRSnZUzAgLWKOVQd196IgBunfBOKaWcERDq6sL3Msrok8j1k9OJi3K3f6aUUqqTcUZAMOFLCADeKBeVOr+RUko5JSCY8NPsAV63i6qaOl09TSnleI4ICNB4CSHaY30E1bUaEJRSzuaIgFDXxEhlX0Co0mojpZTDOSYghBuHABBtr6AWNC22Uko5kCMCQuBI5VDRHquHkQYEpZTTOSIgNDZSGcDr0RKCUkqBQwJCU7Od+toQKmt0bQSllLM5IiDUGYM00u+0PiBoCUEp5WyOCAhNtyFoLyOllAKHBIQm2xC0l5FSSgEOCQjGGBqJB/UlBA0ISimHc0ZAoIm5jLTbqVJKAQ4JCDpSWSmlmueQgNB4G4JWGSmllMURASGSNgQdh6CUcjpnBIQmjulcRkopZXFEQIBGl0PQgWlKKWVzRkBooogQF+3G7RKOllW3X36UUqoTckRAMDQ+/XWU20V6zzi25hW3c66UUqpzcURAgMarjAAGp8aRW1jebnlRSqnOyBEBobnlkhNjoiiuqGmfzCilVCcVUUAQkSkikiUi2SIyK8xxr4jMtY8vF5H0gGP32PuzROTSgP0pIvKmiGwRkc0icmZr3FA4xtBot1OAxBiPBgSllOM1GxBExA08CUwFRgPXiMjokGQzgQJjzHBgNvCIfe5oYAYwBpgCPGVfD+AJ4ANjzEhgHLD5+G+niftootLIKiFUY5orSiilVDcWSQlhIpBtjNlhjKkC5gDTQtJMA160t98ELhSrFXcaMMcYU2mM2QlkAxNFJAk4F3gWwBhTZYw5evy3E55pciQCJMV6qK412vVUKeVokQSEAcDegNc59r6waYwxNUAh0LOJc4cB+cDzIrJGRP4lIvHh3lxEbhKRTBHJzM/PjyC74TVVZZQSGw3A4dKqFl9fKaW6ukgCQrhHaehP7sbSNLbfA5wKPG2MOQUoBRq0TQAYY54xxkwwxkxIS0uLILvhrtH08T5JXgAOFlW06PpKKdUdRBIQcoBBAa8HArmNpRERD5AMHGni3Bwgxxiz3N7/JlaAaBPNtQz0SYoBIK+osq2yoJRSnV4kAWElkCEiQ0UkGquReH5ImvnAdfb21cBiY7XQzgdm2L2QhgIZwApjzAFgr4icaJ9zIbDpOO+lSY0NTANIjbeqjArKtMpIKeVcnuYSGGNqROQWYCHgBp4zxmwUkQeATGPMfKzG4ZdFJBurZDDDPnejiLyO9bCvAW42xvimFf0F8KodZHYAN7TyvQXcQ9PHY6Osjk8V1TrjqVLKuZoNCADGmAXAgpB99wdsVwDTGzn3QeDBMPvXAhOOJbMtZ5ocqRxjB4RyDQhKKQdzxEhlaLqXkdee8bSiWrudKqWcyxEBobkqI5dLiIlyaZWRUsrRHBEQoOkSAljVRhoQlFJO5oiAEMmEFLFRbsqrNCAopZzLGQHBmCbnMgKrhKCNykopJ3NEQIBIq4y0UVkp5VyOCAiRVRlpo7JSytmcERBM0yumgTYqK6WUIwIC0GydUay2ISilHM4RASGSKqOYaA0ISilnc0RAgAiqjDxuKu1GZWMMJZW6pKZSylkcERAiWRozNtrlLyH89JVVnPb7j6jSFdSUUg7iiIAAEXQ79dQ3Ki/cmEdlTR3FFdXtkDOllOocnBMQmjkea7ch1NXVlya02kgp5SSOCAgR1BgR7/VgTPAU2BoQlFJO4oyAgGlyxTSwAgJAaUAQKK3UXkdKKedwRECA5quM4qOtRXJKAya4u/ONtW2YI6WU6lwcERAirTKC4BLC3iPlbZUlpZTqdBwREKD5XkYJdkB4eelu/77xg1IAqKszvLJst05toZTq1hwRECIpIfgCwtzMvf59R8uqAPjv17nc9/YGnvwku03yp5RSnYEzAgLNr4fQIy466HW0x8XhEisgHC2zxiMU2AFCKaW6I0cEBKDZVuXUhOCAkNE7geLKGtbnFFJrj01wN1fvpJRSXZgjAkJEjcp2LyOfE/smAnDH62upsy/gcmlAUEp1X84ICDTf7VRE/I3IAJeN7QfAyH5JVNVacxppCUEp1Z05IiBA872MAOb9bLJ/Oyk2imFp8dQZQ1G51RW1NpKihlJKdVHOCAgRPsddLsFtVwvFRrmJj/ZQVlnD4ZJKAMp05LJSqhtzRkCAZnsZ+fSyG5e9US7iot2UVtVyuNTqXVRapXMbKaW6L0cEBBNpEQGYMqYvYE2HHe/1UF5VyyG7hFCqk90ppboxT0dnoD0YE1kbAsCvvzma6RMGMbhnnFVCqKyh0l4oJ3CeI6WU6m4iKiGIyBQRyRKRbBGZFea4V0Tm2seXi0h6wLF77P1ZInJpyHluEVkjIu8e7400fw+RpfO4XYwdkAxAanw0h0urtISglHKEZgOCiLiBJ4GpwGjgGhEZHZJsJlBgjBkOzAYesc8dDcwAxgBTgKfs6/ncBmw+3ptoTkv7BvVO9FJYXu0vIZRpCUEp1Y1FUkKYCGQbY3YYY6qAOcC0kDTTgBft7TeBC8VagGAaMMcYU2mM2Qlk29dDRAYClwP/Ov7baJoxzU9dEU7vpBj/drTHxc5Dpfzt421Bq6oppVR3EUlAGADsDXidY+8Lm8YYUwMUAj2bOfdx4P8BTa5kLyI3iUimiGTm5+dHkN3GrnPs5yTHRvm3h6TGAfDnj7ayt6CsxflQSqnOKpKAEO5RGvoTubE0YfeLyDeBg8aYVc29uTHmGWPMBGPMhLS0tOZz20hGWiI2qr52a7AdEACqapqMYUop1SVFEhBygEEBrwcCuY2lEREPkAwcaeLcs4ArRWQXVhXUBSLySgvy36biAuY36pXg9W9XakBQSnVDkQSElUCGiAwVkWisRuL5IWnmA9fZ21cDi40xxt4/w+6FNBTIAFYYY+4xxgw0xqTb11tsjPlBK9xPWC2dcSImoISQGFPfQ9c3t5FSSnUnzY5DMMbUiMgtwELADTxnjNkoIg8AmcaY+cCzwMsiko1VMphhn7tRRF4HNgE1wM3GmHbvqmOwJq87VrEhM6D6aJWRUqo7imhgmjFmAbAgZN/9AdsVwPRGzn0QeLCJay8BlkSSj+PRknlKA6uMAuPJg+9tZkd+CRsfmHL8GVNKqU7CEVNXtLTOKLBR+eZvDGdM/yQA1u8rpLSqlsoaHZeglOo+HBEQrCqjYz8vLtoqQF1+cj9S4qJ5bPq4oOMFpdWtkDullOocHDGXEbSsyija4+KrWRfQO9Hrfx3ocGklfZNjwp2qlFJdjiMCwvGsa9M/Jda/He0ODghH7GmxlVKqO3BElRG0rJdRKG9oCaFEA4JSqvtwREA4lvUQmuL1BHdDfXzR1la5rlJKdQbOCAimZW0IoZJig2vYdh3WOY2UUt2HIwICtKyXUcNrhL/IK8t2s2zH4eN/A6WU6kDaqNxC3zplAP9Zs4/L//o5G3OLANj18OWt/0ZKKdVOHFFCsOJBa1QawS3fGA5Ar4RoAH8wUEqprs4RAQFap8oI4M5LRpD94FQSvFENjhlj2JRbxEtLd7XOmymlVDtySJVR69UZiQget+CNahhLS6tq+dZTX1JZU0e028WMiYNb7X2VUqqtOaeE0MrX841JOHt4L9LskcxHy6r8ayXMmre+ld9RKaXalmMCQmurtddVHpYWzx+uGgvA0bJqf3C4anz/DsubUkq1hCMCgjGt14bgc7C4EoA+STGk2GsvF5ZXU1JRA1iL6BRVVHPjCyvZe0THKyilOj9HBAQAaeVKo2Q7CJzYJ5GUOKvH0bIdhymvtqbELq+qZf7aXBZvOcjfFm9r1fdWSqm24IxG5VaauiLQj88Zxol9ErlwVG9/aeFvi7P9x8ura6mwg0PoLKlKKdUZOSMgtEGVUbTHxUWj+wD1pYVA5dV1FJXreglKqa7DMT9dWzsgBIoJWFntynH9OW9EGpXVtZRUWiWEovIa//FFm/LYkV/SdplRSqkWckRAaIOZKxp13eQhpMRFUVZVS2mlFQh8fxdVVPOjlzK54YWV7ZgjpZSKjCMCArR+o3Jj0hJiSO8ZT05BGTlHrd5FxXZAWLf3KACllboWs1Kq83FEQGjNkcrN6Z8Sw/hBKdSZ+nmOSipqqKqp49pnVwD4l+RUSqnOxBkBAVp/qHKIPknWQ97jdpESZzUyHy2zGpVLq2pYvrN+euyEGEe05SuluhjHPJnausLow1+eR2WtVRXkG5fgU1JRQ0V1nf91eZVWGSmlOh9nBIR2qDFKjosCrJJBSkg31OLKGvKKKvyvSyprOFpW1SBwKKVUR3JMlVFjq521haSQgFBVU0dOQTkAI/smsvNQKeMf+MgfJCoCBrGVVdWweEteu+VVKaV8HBEQoO2rjAK5XcKLN04M2rc9v4Q+SV7OHt7Lv++LbYcAOPfRTzj7kcUA/OG9zdz4Qqa/R5JSSrUXRwSE9uxl5DO6XxIAE4emAr6AEEO8t76W7s431lFXZzhYXMmhkioOFFawalcBACt3HWn3PCulnC2igCAiU0QkS0SyRWRWmONeEZlrH18uIukBx+6x92eJyKX2vkEi8omIbBaRjSJyW2vdUOP30NbvECwt0cvOP17GjWelA7CvoJyUuGgSvMHNNlsPFvu3563JYdfhUgCdIVUp1e6aDQgi4gaeBKYCo4FrRGR0SLKZQIExZjgwG3jEPnc0MAMYA0wBnrKvVwPcaYwZBUwCbg5zzVbT/uUDi4j4l9qsrKkj0esJKiEA3D53nX+7usb4F9jZa7c5KKVUe4mkhDARyDbG7DDGVAFzgGkhaaYBL9rbbwIXitWKOw2YY4ypNMbsBLKBicaY/caY1QDGmGJgMzDg+G8nPGPatw0hUOCYg3ivm3ivNe9RD3uswub9Rf7j//dJ/TTZBWVV7ZRDpZSyRBIQBgB7A17n0PDh7U9jjKkBCoGekZxrVy+dAiyPPNvHrj17GQUKrCKK93oY0z+ZYWnxzJo6skHa6lqrLONxCYVlOlOqUqp9RRIQwj1JQ2thGkvT5LkikgC8BfzSGFMUJi0icpOIZIpIZn5+fgTZDfeGHVVpBIkBJYREr4fhvRNYfOf5TD9tEC770zl3RBpx0fUzpsZFuynUqbOVUu0skoCQAwwKeD0QyG0sjYh4gGTgSFPnikgUVjB41Rgzr7E3N8Y8Y4yZYIyZkJaWFkF2w12jA6uMAkoIgdVHLpdgL8tMbV0dZQGjl8cOSOZwaRXZB3WabKVU+4kkIKwEMkRkqIhEYzUSzw9JMx+4zt6+GlhsrL6e84EZdi+koUAGsMJuX3gW2GyM+Utr3EizOigiBP7yP2Vwj7BpBqTEMn5QCgDPXHsaPzxzCADT//5Vo2snLFi/X9dVUEq1qmanrjDG1IjILcBCwA08Z4zZKCIPAJnGmPlYD/eXRSQbq2Qwwz53o4i8DmzC6ll0szGmVkTOBq4F1ovIWvutfmWMWdDaN2jloy2uGpnAtosx/ZOCjv3hqrHMWbmHh751EmXVtUS7Xf7FdsYOSGLDviIu+sun7Pjj5UHnFVdU8/NXVxMf7WbjA1Pa/iaUUo4Q0VxG9oN6Qci++wO2K4DpjZz7IPBgyL4vaOff7O21HkJT4qKDP+4fTBrCDyZZpYEkd3BhzWUHkjoD736dyzdP7u8/9j//WAZAqV3NtDWvmI825fGjc4bi9bhRSqmWcMRI5Y72zLWn8cC0Mcd0TnrPeP/2La+tCToW2FXVGMOfP8ziTwuz+O+6/ceXUaWUozkiIBhj2n2kcqBLxvTlh2emH9M5D35rLH+6+mT/64LS+nEJpwxO8W9PfeJz1trzHq3P0fmPlFIt54iAAB3Xy6ilEmOimD5hEHNumgTAwo0H/MfqAtpEthwoJq+oEoD9hRUopVRLOSIgdGCb8nHz9T46HFBCyC+q8O8PdKBIA4JSquWcERBM+09u11piotzER7s5XFJFTkEZF//lU3ILKzgno1eDtMdTQvjuP5by2/kbjyerSqkuzhEBATpHL6OWSk2I5khpJb9+ewPb7MFql5/cjzsvHuFPc2KfRA6VVFJVU8fTS7Y3O332rLe+5tdvb/C/Xr7zCC98tatV870+p5AnFm1rPqFSqlNwREDoyKkrWkO/5FhyCsqD7mJk3yR+cWGG//WE9B4YY7U1PPLBFqb/fSlgNah/75/LeHNVDje+sJLHF20FYM7Kvby8bDd1dW332Vzxf18we9FWamrrmk+slOpwjggI0HWrjMD69Z91oBivx/q6rpk4uEGa6yanA/DO2n3+fUfLqth7pJyvth/mrjfWsXjLQR5ftI3agCCwIbfQv3xnWymtbNvrK6VahyMCQkeOVG4NZw3vSXFlDQs3Wmst//bK+qUjfNNoD06NY0BKLJ9tPeQ/9sqy3WTublh1dLik0r/92dZ8jjYys+ravUe5ZPanFFUc30R7xZU6UZ9SXUFEI5W7OkPXLiGcf2Jv4qLdlFXVcmKfxKDRyPN+fhYrdh4mJspNRp8E9h2tX1jnsQ+3hr1efkBA2HawJGhm1el//4qff2M4KbFRfOuprwBYt/coJ6Ql8OLSXdxx8QheW76HsQOSOT3dWh70YHEF+cWVjOmfHPb9SiprWnzvSqn244iAYOm6ESEmys3p6al8ujWf750RXF00tFc8Q3tZo5ovG9uPJVnNTxF+qKS+C2tJRU1QQFi5q4DHFmaxMbd+NPSR0iqufXYxAGkJXv7w3mYAFt1xLsN7J3L+n5ZQVlXLroeD51wKfA+lVOenVUZdRFKsVTUUuiZzoBN6x4fdf3p68Cyr++zlOfsmxVBcWdNg7YVxIWMcPt580L+9eX/9GtAX/eUzAP/U3VU19Y3Hge0UxXYJobyqltkfbWV7fknEa0bvPVKm60sr1U4cERCgY6euaA1R9mo6tU1Et5MHpnDjWUP5x7Wn+ffdfemJ/O7KsUHpthywfv2n94qjuKKGopCAUF0T3Ctoa15xg3N9DhbXj304FFAVVVZVXyrwlRD+tngbT3y8jQv//CnnPPpJo/cR6PzHlnDOo59gukNUV6qTc0yVURePB/4SgquJyBbldnH/FVaD892XnsjFo/swok8i+cWVQeleWrobr8dFv+RY9h090qCE4BvxPGvqSF5dvpsdh0r9x3YGbAPszK9/nVNQTv+UWICgBX98bQiBgSVSvpLGkdIqeiZ4j/l8pVTkHFFC6A4/Lu+4ZAQ/OW8YV4zrF1H6m78xnBF9EoH6nkiBjIGe8dHsPVLO3JV7g47lBFQp9U2K8VcFJXo9/gf9xKFWg3JuYX0j9q/f3uAvMZRWNiwh+NaMbonAdo+28NGmPNbnFLbpeyjV2TkiIEDX7mUEkBQTxT1TR7VovQNPwFoLT37vVACqaus484SeAGTlFRPtqU9z0C4hJMV66JMU498/pFecf/vyk6zAdPvcdf59WXnF3PLaGt5clcPXAQ/XBxds5pOsg0HVSBDcztCcwOqotvDjlzK54v++aNP3UKqzc0RA6AYFhOMWG+UmKcbDiX0T/fsCl/SMj3Zz3+WjSPR6/AvvxES56ZccEBDsNRpcAhl9EsK+z4qdR7jrjXX8cu7aoP03PL+yQdXVEXvCvto6w5OfZPun8fYJbDc4VFLJweIK0me9x4L1uu6DUm3BGQHBmC49l1FryLzvIpb96kL/A37coBRS46M5eaA1dqCgrJofnTOMC0f19p8TF+2hb3Ks//XAHtZ2vNfD2AHBYw76BpQkwnG7pEFA+DrnKH94dxMn/GoBf1qYxR2vr2XP4TL/yOnygBHU+cWV3Pcfa+6ll5fubnD9d9buI33WexwtC65aWrHzCOmz3mN7E+tPt/VIbaW6CkcEBOj6VUbHK97rIS7aQ7zXwws3nM5z100A4MUbJgIwaZjVJtAroOE2NqCEEOUW0uxjxlhVWM/a14Dw02kAfHj7udx72Shq64y/5OGzNa+Ef32x0/96R34p5/7pE+56w6qGChy/cLi0igL7Ye9xC3sOl3HbnDWUV9Xy0aY8bptjlUh2hDR6/2eNNZXHp02Mz2hspLZSTuOIgKBVRsHOP7G3v8dOj/hoPr37fP5xrfVw75VYHxDiot2MsKuG/nfKSNLsY75f7qcPTWVgj1h+d+UYbr1wONPGW+s+J8Z4gq4RWOrw8XpcrN8XfoW3DzZYiwEVBzRMP71ku//BvS2vhMc/3so7a3P579e5/PilTH+6vJApwH3zP5VW1vD0ku3+awOc8+hiHnxvE0fLwzdYF5ZXB3WrDVVdWxc0MrwlNuwr5MvsQ80nVKodOKLbqTFdv9tpWxoSsH5zYAkhJsrNoNQ4vpyz/H8WAAAVhklEQVR1Af2TY1i64zBQ3xicFBPFF/97gT/9+EEpvLM2lwEpsWw5YHUxTYqNIikmipX3XsT8dblcOa4/bpdwxd++YMH6+odzoBr7+sUhI5x9U38fKKqgxu6x9Oqy4Oqj/YUVPPLBFpZk5fPU90/1T+ldVFHNPz+3SiO7Hr6c6to69h4p55+f7yTBW98LyxiDMTBvzT7+9fkOthwoZsvvpxAT1bAx/6EFm3n+y12s+fXF9IiPDrrGvqPlDOwR1+CcUNf8cxnFFTWsvf9iUuKim02vVFtyRAkBQJxeZxShXgn1D6W4aOshOCAlFpH6KqPG+EodcdFuHvnOSdx72SiSYqyHbVqil5lnDyUt0UtqfHTQL+vYKDduV/D3U1RR7a8yeuhbJ/n3+0osX223flWvC+kqeri0kqeXbGfz/iK+8dgS//7A4DJvdQ5/+7h+nYbZi+rnfMrcXcBjH2Zx1xvr/EEt8Bf8y8t2c8FjS9hfWO7fnx3QPrEtr5jZH23l7Ec+YVdI9VU4vnwt39n0+hWB8ooqOKDLpao24IiAoKNcIxfahhBoQI9Y4u2HfThxdvp4r4fvnj6YH587rNH3+dVlI/3bmx64lOwHpwYdP/m3H1Jiz5I6flAKI+3eUd8YaVU/hY5LeGDaGJJiPGzLC994HNigfcfr6/jr4uyw6ab/fSlPLdketG/mi5m8unw3+wvL+f1/N7HjUCmfbzvk/6y25hVTUV3LtCe/5OLZn/mvHTiJ4N4jZZzz6GK+Cqkeiomy/guuPIaAMOXxz7jgz0si+nddUV3boKE9nI25hZzwqwXsOazThDiZIwKCipyvETkpxoMr5Fd7XLSHjQ9M4bunh29A7mufe/bwhst7hrrp3BMAyOidgIggInw56wJevHGiP83eI1YpIinWw0PfPokrxvXnZ+ed4J/PqU+S9UBOjY/mh2em0y85ls+3WQ9c38A5gHEDk9m8P3jKDZ+J6alh94e69z8bOPOPi6myF/t5/stdfLXdqkLLOlDMnBV7WBfSbfZIaRW3z13LE4u28c7afew9Us4Tdsnk+udX8Id3N1FRbV3viwjbEfKLKykoq6asqpbt+c2XQKb/fSnjH/io2XTPf7mL2jrD59lW4/sX2w7xr893NHteUUU19729PuwU6UUV1dqDq4txRhtCR2egC+mZ4OXZ6ya0aJqIsQOS+fD2c8noHX6MQqgvZ10Q1AA9ICWWXgnRJHg9lFTW8OACa1bVXgleBvaI41R73ITv+KmDe/D+hgP0tOvvU+OjycorxiXwy4sy+N4/l3PJ6D6UV9c2qFry+e2VY7jsr583msefnDcMr8fNG5l7g9asDgwwL4XpBgvwk5dX+bdPGWxNGHiopJKNuYUsycr3z0w7LC2eLQeKueP1tfzmijHsyC8hKTaKE9Ksz7Gooppot4uYKHfQXFJ7jpQyvInPuq7OsH6fdd8FpVVB7Rxgjc7+84dZPHv96eTZgxF91YQ/eHY5AJeO6cug1Pq2kEWb8jhreC9i7XQvfbWLV5btoW9SDLdckBF0/ZN/+yGnDenBWz+b3Ggej9f+wnJ2Hy5j0rCeLTrfGKPTogRwREDAaLfTY3HhqD4tPtc3XUYkBqTENtjn9bj5+M7zOOOhj/37Qht0fe0NF43qQ22d4e5LTwSstacBRvVLYvIJvfjo9nMZ3DOOu974utE8DOkZvuE3PtrNht9d6m97+sUFw/njgi089+VOxvRPCpoe3EcETh6QHDb4rNljlR6255fy55B1Kk4fksqO/FLmrd5HotfDi0t30zM+mhvOSmfvkXL+s3Yfo/omcnZGr6Cg9NLS3TyxaBsv3DCxwcMerJHjPhtzizg7o77kVldn/L2z3l6zj4NFVvVWYUgX3Oe+3MlvrhgDwOo9BfzopUyun5zOb6+09vmmIwntUuxr41i1u6BBvlrTeY8uoaq2ju0PXeb/d1FRXUudMcRFN/94+8tHW/nb4mwy77soqLrUqRxTZeT0gWldSe/Epv9jjrd/bQ9Ni+eZH04gww5CMfa0Hr5qrwx7MaE+9vVOHpjM9ZPT+eO3T+LZ6ybw1s8mE+/18Ox1E3j++tOD3qO0qjaoI4Jv4sCdf7yMwfYvZq/HxQPTxvjT3Hf5aC4/2ZrSY8bpg4h2u/j1N0f7079kV4ct3lI/nTjAOSPqH9TvrMsFrHEXj324lbmZe6mqqWNdTiFPfrKdeav30Sshmii3sCQrn3U5hZzy+4+4Y+5afzvJx5vzuHT2Z0x9or7kszE3OEiNf+DDoGN77CnGC+yA4KuOC6wG22o3su86XF9VtcJu+8gvruTWf68hfdZ7fPupL3kloPdXtV3NVlpZ498OdO9/1nPPvPUN9jenrs74q/B8U7oDXP33rxj/u+aryYwx/M1u78ncFXkbTnfmiBKCVhl1LSLCJ3edz+ItB5kwpEeD43+YNpb/mTCIU0LWbehplxB8Dc8+vrEVg1Pj/L9sA/lKRNNPG8ig1Dj+8tFW/6jscHm7ZuJg3t9wgHk/n8zofkmM6Z/MKYNScLkEYwynDUnl1MEpPPydkwH45sn9qKiu9QcGgJS4KI6WVSMS3OYSySC5sQOSSY6N4p21uUS5hepaw7w1+9h1uJTXfjyJmS/Wj8sYPyiFg0UVbMwt4q1VOXy46QD/uHYCRQG9rgK7/z7x8TbW5Rwlr6iSKLewYV8RlTW1eD1uPrfbOfKLK1m39yj9kmNYvccqAby5Ksd/jdV7jrJ6T30gybj3ff/2t04ZwCPfOZnZi7Yyok8CC9Yf4KNN1tKwd14ygl4JXr7OOcr2/BKuOLk/CzfmsetwKROHpnJ6eip1dYasvGIWbznIeSPS/NfNzi9msF3a27DPKr1t3l/EkJ5xjZYUngjoaZZ1oIQpY8MmcxRnBATT9ddDcJqhveKZefbQsMd6xEcHPQx8br0wgyvH9W8wrcZZJ/TijKGp/qqlxvxp+jjAahwPV53lc+6INLY9OJUoe9LA0wKClogEvQaCJgh89xdnc/ebXzNr6kjyCis4sW8iKXHRfHT7udzy2hqy8oq5eHQfbr0ggzpj+N1/N3K0rJphafEYA4uzDnLTOcOI93p4Z20uN5w1lGc+sxp/V+85yshff+B/r4npqbw0cyI/e2UV89flMt8ufWwLqEq6+rSBvLkqh2i3i5MGJrNqd4G/beNn553AXxdns/NQKSP7JrHZribbmFvEtCe/DLg/L3lFlSTGWKPgv/P00kY/u/+s2cfyHYfJDdNt9o3MHF7P3OufYv1oWTW/++8mwGo3Wv/bS7j5tdW8bw8u/NPCLP+52QdLuGBkHwpK63tU+UpIPz//BO6+9EQOFley81ApE9NTqaypY5k9riYt0cuOQyV8vi2fg0WVJMR4GNM/Kew4kro6w4pdRzg9PbVBV2mw1jH/bGs+fZJi+P1VDSPMqt0FDE6N44WvdnLjWUMbtF18mX2I8YNSiG9iIay2FNG7isgU4AnADfzLGPNwyHEv8BJwGnAY+K4xZpd97B5gJlAL3GqMWRjJNVubxoPuLyHMHEsAJw1MZu5Pzoz4Ov8zYVCzaaLcLattHTsgmfdvO6fB/ow+iSTb05RPPqEnJ9lzTL3x08kYY/wz1tbU1vm35/18MqP6JnH3pSfy+bZ8bnzBKhkMSo3l07u+4e8llhwbPP35xbOtle5mf3ccgvDmqhxOGpjMxaP7BNX5XzKmL39dnM13nvrK30YwMT2VFSHVK7OmjuT2ueu46ZxhnDYklSE949gd0n11wpAeHCmtYseh0gbB4PrJ6by8bDePfLAFgGi3i6raOn8wSPR6KK6sYf66XN7fcIDU+Gj/xIhj+idRVlXLQwu2UF1r/EFiWK94/zQmTy3Zzta8EhZtzmvwuX/n1IEcLK5g56FSrn12hX//xKGpvG7/myksr+aZz7bz8/OH8+ry3Ty0YAu3XzSC2y4KbkQvLK/mvrc3+F+PH5TCd04b6H/93tf7ufm11cREuaioruO15Xv44n8vwO0SYqLcvLJsN/e9vYHvnTE4aOxNe2o2IIiIG3gSuBjIAVaKyHxjzKaAZDOBAmPMcBGZATwCfFdERgMzgDFAf2CRiIywz2numq1Gq4xUV/D9MwazYueRoNKP9Su0/udM4FTmpwbMVjtpWE/OHZHGVeP7M2Vs36Auw3dPGUlVbR2xUR7eWl1ftXPluAHU2aOqrxzXn7REL4N6xFFQVsWQnnH+Xk6BDcZ3XjKC7z6zDIApY/ry7VMHcMmYvpw3ord/3Y05N01i1lvrOXVwD3rER3HSgGROGdyDujrDT15ZxaAecWT0SeCeeev5948nceYJPSmuqOGt1Tlk9E5gwW3n8GlWPj+yG72f+sGpXPvsCm6bs5bBqXG8e+vZzF+by31vb+C+y0ezLucoD7+/JajE8Nz1p7N6TwE1tYanP90eNhgAXHVKfxZtyuPFkJ5iK3Ye4cw/fsxD3zqJFbuO8PSS7Szdftgf6GYv2srsRVu5fnI691w2kp+8vKrByoOvrdjD9vwSnlqy3V9FCPi7GheUVTPmNwsBazr59+xZfN9dl8v93xzNZ1vzOXdEWthR8m1FmhvcIiJnAr81xlxqv74HwBjzx4A0C+00S0XEAxwA0oBZgWl96ezTmrxmOBMmTDCZmZlNJQlr9P0f8P0zBnPv5aOP+Vyl2lNZVU1EvWNaavWeAr791Fc8evXJEZWEvvP0V6zfV8iT3zuV/OJKvnfGYOau3EPvxJgGbTXHKvBeK6pr+XzbIc48oScJXg91ddaDfNKwVE4d3IOMe9+nps6w6r6L6JngxRhrssQEr4fq2jq+989lrNxllW5G9k3kg1+e63+fL7Yd8nejXfebS3C7hOyDJazaXcDMs4eyeX8RDy3YzFfbDzOiTyJj+icxf11u0BrhgWZNHcnD728Je+yiUb15fMYpvLR0F49+kNXgeGyUm/LqWk4ZnELu0XLyiuoHLw7rFc8tFwznjtfXcdX4/ry91qrie/0nZwaNqzlWIrLKGDOh+ZSRVRkNAAKX1MoBzmgsjTGmRkQKgZ72/mUh5w6wt5u7ZqvRgcqqq2jLYABWqWLXw5dHnP7VH51BRXVt0DxLjQ1MPFaB9xoT5ebi0fXdnV0u4eZvDPe//uSu8ymqqPbXuYuIf4BilNvFv354Ok8uyebaSUP8kzD6nJ3Ri8emj6NnQrS/+mz8oBTG250SRvVL4uWZZ1BSWeO/5i8uGM6/V+zl+S93UllTx68uG8lLS3fjEuH6yemM6JPA4ZIq/vzhVv+Ss32SvPzUHjj5o7OH8e8VezhUXMV/bp5M/5RYEr0elmzN54bnV3La4B689dPJHCqt5KH3NvP22ly+feoAu0Tg8gcDsOa7Om1wD17/aeTVni0Vyb++cNXvoY/YxtI0tj9cBWzYx7aI3ATcBDB4cMv+IU4Z25dR/ZJadK5SThYT5W7XKovGBA6OCyc5LopfXTaq0eNXB9TlNyYhoCF3SM94Zk0dyQ1npfP5tkNcMa4f105KR8T6TC4YafdMmzCIz7bmk9EngX4Ba4dEe1y8+4tzwOBvGwI4f0Qaz99wOqenp+JyCb0TY3j06nFcfnJ/zhuRRrTHxVPfP5X31x/g1gszcLmEJxZtbXIt9dbkiCojpZRyqmOpMoqkq8RKIENEhopINFYj8fyQNPOB6+ztq4HFxoo084EZIuIVkaFABrAiwmsqpZRqR81WGdltArcAC7G6iD5njNkoIg8AmcaY+cCzwMsikg0cwXrAY6d7HdgE1AA3G2NqAcJds/VvTymlVKSarTLqTLTKSCmljk1rVxkppZRyAA0ISimlAA0ISimlbBoQlFJKARoQlFJK2bpULyMRyQfCr1fYvF5AZAvXdh96z86g99z9Hc/9DjHGNJwvPowuFRCOh4hkRtr1qrvQe3YGvefur73uV6uMlFJKARoQlFJK2ZwUEJ7p6Ax0AL1nZ9B77v7a5X4d04aglFKqaU4qISillGpCtw8IIjJFRLJEJFtEZnV0flqLiAwSkU9EZLOIbBSR2+z9qSLykYhss//uYe8XEfmr/Tl8LSKnduwdtJyIuEVkjYi8a78eKiLL7Xuea0+pjj3t+lz7npeLSHpH5rulRCRFRN4UkS32931md/+eReR2+9/1BhH5t4jEdLfvWUSeE5GDIrIhYN8xf68icp2dfpuIXBfuvSLVrQOCiLiBJ4GpwGjgGhHpLgsr1wB3GmNGAZOAm+17mwV8bIzJAD62X4P1GWTYf24Cnm7/LLea24DNAa8fAWbb91wAzLT3zwQKjDHDgdl2uq7oCeADY8xIYBzWvXfb71lEBgC3AhOMMWOxpsifQff7nl8ApoTsO6bvVURSgd9gLUE8EfiNL4i0iDGm2/4BzgQWBry+B7ino/PVRvf6DnAxkAX0s/f1A7Ls7X8A1wSk96frSn+AgfZ/lAuAd7GWaT0EeEK/c6z1Ns60tz12OunoezjG+00Cdobmuzt/z9Sv0Z5qf2/vApd2x+8ZSAc2tPR7Ba4B/hGwPyjdsf7p1iUE6v9h+eTY+7oVu4h8CrAc6GOM2Q9g/93bTtZdPovHgf8H1NmvewJHjTE19uvA+/Lfs3280E7flQwD8oHn7Wqyf4lIPN34ezbG7AMeA/YA+7G+t1V07+/Z51i/11b9vrt7QAi3MnW36lYlIgnAW8AvjTFFTSUNs69LfRYi8k3goDFmVeDuMElNBMe6Cg9wKvC0MeYUoJT6aoRwuvw921Ue04ChQH8gHqvKJFR3+p6b09g9tuq9d/eAkAMMCng9EMjtoLy0OhGJwgoGrxpj5tm780Skn328H3DQ3t8dPouzgCtFZBcwB6va6HEgRUR8y8EG3pf/nu3jyVhLvHYlOUCOMWa5/fpNrADRnb/ni4Cdxph8Y0w1MA+YTPf+nn2O9Xtt1e+7uweElUCG3TshGqthan4H56lViIhgrWW92Rjzl4BD8wFfT4PrsNoWfPt/aPdWmAQU+oqmXYUx5h5jzEBjTDrWd7nYGPN94BPgajtZ6D37Pour7fRd6pejMeYAsFdETrR3XYi1Rnm3/Z6xqoomiUic/e/cd8/d9nsOcKzf60LgEhHpYZesLrH3tUxHN6q0Q6PNZcBWYDtwb0fnpxXv62ysouHXwFr7z2VYdacfA9vsv1Pt9ILV42o7sB6rB0eH38dx3P/5wLv29jBgBZANvAF47f0x9uts+/iwjs53C+91PJBpf9dvAz26+/cM/A7YAmwAXga83e17Bv6N1UZSjfVLf2ZLvlfgRvves4EbjidPOlJZKaUU0P2rjJRSSkVIA4JSSilAA4JSSimbBgSllFKABgSllFI2DQhKKaUADQhKKaVsGhCUUkoB8P8BzfXN2i0p9S8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7fd858cf6518>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from scipy.signal import correlate\n",
    "def calculate_natural_distribution(model, X_adv, y_adv, batch_size=256, kernel=np.array([0.1, 0.2, 0.4, 0.2, 0.1])):\n",
    "    with torch.no_grad(), lib.training_mode(model, is_train=False):\n",
    "        logits = lib.process_in_chunks(lambda X_batch:model(X_batch.to(device)), X_adv, batch_size=batch_size)\n",
    "        sorted_ans = logits.topk(k=logits.shape[1], dim=1).indices.to('cpu')\n",
    "        y_rank = (sorted_ans == y_adv.view(-1,1)).argmax(dim=1)\n",
    "        bin_counts = np.bincount(lib.check_numpy(y_rank), minlength=logits.shape[-1]).astype('float32')\n",
    "    soft_counts = correlate(bin_counts, kernel)[(len(kernel) - 1) // 2: (len(kernel) - 1) // 2 + len(bin_counts)]\n",
    "    soft_counts[0] = 0\n",
    "    assert len(soft_counts) == len(bin_counts)\n",
    "    return soft_counts / soft_counts.sum()\n",
    "\n",
    "soft_counts = calculate_natural_distribution(models.resnet18(pretrained=True).to(device), X_adv, y_adv)\n",
    "\n",
    "plt.plot(soft_counts)\n",
    "\n",
    "cumsum = soft_counts.cumsum() / soft_counts.sum()\n",
    "cumsum = torch.as_tensor(cumsum, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_params = set(model.editable.module[0].parameters())\n",
    "old_params = [param for param in model.parameters() if param not in new_params]\n",
    "\n",
    "training_opt = lib.OptimizerList(\n",
    "    torch.optim.SGD(old_params, lr=1e-5, momentum=0.9, weight_decay=1e-4),\n",
    "    torch.optim.SGD(new_params, lr=1e-3, momentum=0.9, weight_decay=1e-4),\n",
    ")\n",
    "\n",
    "trainer = EditableMatchRankTrainer(cumsum, model=model,\n",
    "          stability_coeff=0.03, editability_coeff=0.03,\n",
    "          experiment_name=experiment_name,\n",
    "          error_function=classification_error,\n",
    "          opt=training_opt, max_norm=10)\n",
    "\n",
    "trainer.writer.add_text(\"trainer\", repr(trainer).replace('\\n', '<br>'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm_notebook, tnrange\n",
    "from IPython.display import clear_output\n",
    "\n",
    "# Learnign params\n",
    "eval_batch_cd = 500\n",
    "val_metrics = trainer.evaluate_metrics(X_test.to(device), y_test.to(device))\n",
    "min_error, min_drawdown = val_metrics['base_error'], val_metrics['drawdown']\n",
    "early_stopping_epochs = 500\n",
    "number_of_epochs_without_improvement = 0\n",
    "            \n",
    "def edit_generator():\n",
    "    while True:\n",
    "        for xb, yb, lg in torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2):\n",
    "            yield xb.to(device), torch.randint_like(yb, low=0, high=max(y_test) + 1, device=device)\n",
    "\n",
    "edit_generator = edit_generator()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "while True:\n",
    "    \n",
    "    for x_batch, y_batch, logits in tqdm_notebook(train_loader):\n",
    "        trainer.step(x_batch.to(device), logits.to(device), *next(edit_generator))\n",
    "        \n",
    "        if trainer.total_steps % eval_batch_cd == 0:\n",
    "            val_metrics = trainer.evaluate_metrics(X_test.to(device), y_test.to(device))\n",
    "            clear_output(True)\n",
    "\n",
    "            error_rate, drawdown = val_metrics['base_error'], val_metrics['drawdown']\n",
    "\n",
    "            number_of_epochs_without_improvement += 1\n",
    "\n",
    "            if error_rate < min_error:\n",
    "                trainer.save_checkpoint(tag='best_val_error')\n",
    "                min_error = error_rate\n",
    "                number_of_epochs_without_improvement = 0\n",
    "\n",
    "            if drawdown < min_drawdown:\n",
    "                trainer.save_checkpoint(tag='best_drawdown')\n",
    "                min_drawdown = drawdown\n",
    "                number_of_epochs_without_improvement = 0\n",
    "\n",
    "            trainer.save_checkpoint()\n",
    "            trainer.remove_old_temp_checkpoints()\n",
    "\n",
    "            if number_of_epochs_without_improvement > early_stopping_epochs:\n",
    "                break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Eval metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.load_checkpoint(path='best_val_error');\n",
    "# if you're not running this in the same notebook, you can also select the checkpoint via path\n",
    "# trainer.load_checkpoint(path='./logs/EXPERIMENTNAMEHERE/checkpoint_best_val_error.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Adv metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean = [0.485, 0.456, 0.406]\n",
    "std = [0.229, 0.224, 0.225]\n",
    "\n",
    "test_transform = transforms.Compose(\n",
    "    [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean, std)])\n",
    "\n",
    "thousand_k_to_200 = {0: -1, 1: -1, 2: -1, 3: -1, 4: -1, 5: -1, 6: 0, 7: -1, 8: -1, 9: -1, 10: -1, 11: 1, 12: -1, 13: 2, 14: -1, 15: 3, 16: -1, 17: 4, 18: -1, 19: -1, 20: -1, 21: -1, 22: 5, 23: 6, 24: -1, 25: -1, 26: -1, 27: 7, 28: -1, 29: -1, 30: 8, 31: -1, 32: -1, 33: -1, 34: -1, 35: -1, 36: -1, 37: 9, 38: -1, 39: 10, 40: -1, 41: -1, 42: 11, 43: -1, 44: -1, 45: -1, 46: -1, 47: 12, 48: -1, 49: -1, 50: 13, 51: -1, 52: -1, 53: -1, 54: -1, 55: -1, 56: -1, 57: 14, 58: -1, 59: -1, 60: -1, 61: -1, 62: -1, 63: -1, 64: -1, 65: -1, 66: -1, 67: -1, 68: -1, 69: -1, 70: 15, 71: 16, 72: -1, 73: -1, 74: -1, 75: -1, 76: 17, 77: -1, 78: -1, 79: 18, 80: -1, 81: -1, 82: -1, 83: -1, 84: -1, 85: -1, 86: -1, 87: -1, 88: -1, 89: 19, 90: 20, 91: -1, 92: -1, 93: -1, 94: 21, 95: -1, 96: 22, 97: 23, 98: -1, 99: 24, 100: -1, 101: -1, 102: -1, 103: -1, 104: -1, 105: 25, 106: -1, 107: 26, 108: 27, 109: -1, 110: 28, 111: -1, 112: -1, 113: 29, 114: -1, 115: -1, 116: -1, 117: -1, 118: -1, 119: -1, 120: -1, 121: -1, 122: -1, 123: -1, 124: 30, 125: 31, 126: -1, 127: -1, 128: -1, 129: -1, 130: 32, 131: -1, 132: 33, 133: -1, 134: -1, 135: -1, 136: -1, 137: -1, 138: -1, 139: -1, 140: -1, 141: -1, 142: -1, 143: 34, 144: 35, 145: -1, 146: -1, 147: -1, 148: -1, 149: -1, 150: 36, 151: 37, 152: -1, 153: -1, 154: -1, 155: -1, 156: -1, 157: -1, 158: -1, 159: -1, 160: -1, 161: -1, 162: -1, 163: -1, 164: -1, 165: -1, 166: -1, 167: -1, 168: -1, 169: -1, 170: -1, 171: -1, 172: -1, 173: -1, 174: -1, 175: -1, 176: -1, 177: -1, 178: -1, 179: -1, 180: -1, 181: -1, 182: -1, 183: -1, 184: -1, 185: -1, 186: -1, 187: -1, 188: -1, 189: -1, 190: -1, 191: -1, 192: -1, 193: -1, 194: -1, 195: -1, 196: -1, 197: -1, 198: -1, 199: -1, 200: -1, 201: -1, 202: -1, 203: -1, 204: -1, 205: -1, 206: -1, 207: 38, 208: -1, 209: -1, 210: -1, 211: -1, 212: -1, 213: -1, 214: -1, 215: -1, 216: -1, 217: -1, 218: -1, 219: -1, 220: -1, 221: -1, 222: -1, 223: -1, 224: -1, 225: -1, 226: -1, 227: -1, 228: -1, 229: -1, 230: -1, 231: -1, 232: -1, 233: -1, 234: 39, 235: 40, 236: -1, 237: -1, 238: -1, 239: -1, 240: -1, 241: -1, 242: -1, 243: -1, 244: -1, 245: -1, 246: -1, 247: -1, 248: -1, 249: -1, 250: -1, 251: -1, 252: -1, 253: -1, 254: 41, 255: -1, 256: -1, 257: -1, 258: -1, 259: -1, 260: -1, 261: -1, 262: -1, 263: -1, 264: -1, 265: -1, 266: -1, 267: -1, 268: -1, 269: -1, 270: -1, 271: -1, 272: -1, 273: -1, 274: -1, 275: -1, 276: -1, 277: 42, 278: -1, 279: -1, 280: -1, 281: -1, 282: -1, 283: 43, 284: -1, 285: -1, 286: -1, 287: 44, 288: -1, 289: -1, 290: -1, 291: 45, 292: -1, 293: -1, 294: -1, 295: 46, 296: -1, 297: -1, 298: 47, 299: -1, 300: -1, 301: 48, 302: -1, 303: -1, 304: -1, 305: -1, 306: 49, 307: 50, 308: 51, 309: 52, 310: 53, 311: 54, 312: -1, 313: 55, 314: 56, 315: 57, 316: -1, 317: 58, 318: -1, 319: 59, 320: -1, 321: -1, 322: -1, 323: 60, 324: 61, 325: -1, 326: 62, 327: 63, 328: -1, 329: -1, 330: 64, 331: -1, 332: -1, 333: -1, 334: 65, 335: 66, 336: 67, 337: -1, 338: -1, 339: -1, 340: -1, 341: -1, 342: -1, 343: -1, 344: -1, 345: -1, 346: -1, 347: 68, 348: -1, 349: -1, 350: -1, 351: -1, 352: -1, 353: -1, 354: -1, 355: -1, 356: -1, 357: -1, 358: -1, 359: -1, 360: -1, 361: 69, 362: -1, 363: 70, 364: -1, 365: -1, 366: -1, 367: -1, 368: -1, 369: -1, 370: -1, 371: -1, 372: 71, 373: -1, 374: -1, 375: -1, 376: -1, 377: -1, 378: 72, 379: -1, 380: -1, 381: -1, 382: -1, 383: -1, 384: -1, 385: -1, 386: 73, 387: -1, 388: -1, 389: -1, 390: -1, 391: -1, 392: -1, 393: -1, 394: -1, 395: -1, 396: -1, 397: 74, 398: -1, 399: -1, 400: 75, 401: 76, 402: 77, 403: -1, 404: 78, 405: -1, 406: -1, 407: 79, 408: -1, 409: -1, 410: -1, 411: 80, 412: -1, 413: -1, 414: -1, 415: -1, 416: 81, 417: 82, 418: -1, 419: -1, 420: 83, 421: -1, 422: -1, 423: -1, 424: -1, 425: 84, 426: -1, 427: -1, 428: 85, 429: -1, 430: 86, 431: -1, 432: -1, 433: -1, 434: -1, 435: -1, 436: -1, 437: 87, 438: 88, 439: -1, 440: -1, 441: -1, 442: -1, 443: -1, 444: -1, 445: 89, 446: -1, 447: -1, 448: -1, 449: -1, 450: -1, 451: -1, 452: -1, 453: -1, 454: -1, 455: -1, 456: 90, 457: 91, 458: -1, 459: -1, 460: -1, 461: 92, 462: 93, 463: -1, 464: -1, 465: -1, 466: -1, 467: -1, 468: -1, 469: -1, 470: 94, 471: -1, 472: 95, 473: -1, 474: -1, 475: -1, 476: -1, 477: -1, 478: -1, 479: -1, 480: -1, 481: -1, 482: -1, 483: 96, 484: -1, 485: -1, 486: 97, 487: -1, 488: 98, 489: -1, 490: -1, 491: -1, 492: 99, 493: -1, 494: -1, 495: -1, 496: 100, 497: -1, 498: -1, 499: -1, 500: -1, 501: -1, 502: -1, 503: -1, 504: -1, 505: -1, 506: -1, 507: -1, 508: -1, 509: -1, 510: -1, 511: -1, 512: -1, 513: -1, 514: 101, 515: -1, 516: 102, 517: -1, 518: -1, 519: -1, 520: -1, 521: -1, 522: -1, 523: -1, 524: -1, 525: -1, 526: -1, 527: -1, 528: 103, 529: -1, 530: 104, 531: -1, 532: -1, 533: -1, 534: -1, 535: -1, 536: -1, 537: -1, 538: -1, 539: 105, 540: -1, 541: -1, 542: 106, 543: 107, 544: -1, 545: -1, 546: -1, 547: -1, 548: -1, 549: 108, 550: -1, 551: -1, 552: 109, 553: -1, 554: -1, 555: -1, 556: -1, 557: 110, 558: -1, 559: -1, 560: -1, 561: 111, 562: 112, 563: -1, 564: -1, 565: -1, 566: -1, 567: -1, 568: -1, 569: 113, 570: -1, 571: -1, 572: 114, 573: 115, 574: -1, 575: 116, 576: -1, 577: -1, 578: -1, 579: 117, 580: -1, 581: -1, 582: -1, 583: -1, 584: -1, 585: -1, 586: -1, 587: -1, 588: -1, 589: 118, 590: -1, 591: -1, 592: -1, 593: -1, 594: -1, 595: -1, 596: -1, 597: -1, 598: -1, 599: -1, 600: -1, 601: -1, 602: -1, 603: -1, 604: -1, 605: -1, 606: 119, 607: 120, 608: -1, 609: 121, 610: -1, 611: -1, 612: -1, 613: -1, 614: 122, 615: -1, 616: -1, 617: -1, 618: -1, 619: -1, 620: -1, 621: -1, 622: -1, 623: -1, 624: -1, 625: -1, 626: 123, 627: 124, 628: -1, 629: -1, 630: -1, 631: -1, 632: -1, 633: -1, 634: -1, 635: -1, 636: -1, 637: -1, 638: -1, 639: -1, 640: 125, 641: 126, 642: 127, 643: 128, 644: -1, 645: -1, 646: -1, 647: -1, 648: -1, 649: -1, 650: -1, 651: -1, 652: -1, 653: -1, 654: -1, 655: -1, 656: -1, 657: -1, 658: 129, 659: -1, 660: -1, 661: -1, 662: -1, 663: -1, 664: -1, 665: -1, 666: -1, 667: -1, 668: 130, 669: -1, 670: -1, 671: -1, 672: -1, 673: -1, 674: -1, 675: -1, 676: -1, 677: 131, 678: -1, 679: -1, 680: -1, 681: -1, 682: 132, 683: -1, 684: 133, 685: -1, 686: -1, 687: 134, 688: -1, 689: -1, 690: -1, 691: -1, 692: -1, 693: -1, 694: -1, 695: -1, 696: -1, 697: -1, 698: -1, 699: -1, 700: -1, 701: 135, 702: -1, 703: -1, 704: 136, 705: -1, 706: -1, 707: -1, 708: -1, 709: -1, 710: -1, 711: -1, 712: -1, 713: -1, 714: -1, 715: -1, 716: -1, 717: -1, 718: -1, 719: 137, 720: -1, 721: -1, 722: -1, 723: -1, 724: -1, 725: -1, 726: -1, 727: -1, 728: -1, 729: -1, 730: -1, 731: -1, 732: -1, 733: -1, 734: -1, 735: -1, 736: 138, 737: -1, 738: -1, 739: -1, 740: -1, 741: -1, 742: -1, 743: -1, 744: -1, 745: -1, 746: 139, 747: -1, 748: -1, 749: 140, 750: -1, 751: -1, 752: 141, 753: -1, 754: -1, 755: -1, 756: -1, 757: -1, 758: 142, 759: -1, 760: -1, 761: -1, 762: -1, 763: 143, 764: -1, 765: 144, 766: -1, 767: -1, 768: 145, 769: -1, 770: -1, 771: -1, 772: -1, 773: 146, 774: 147, 775: -1, 776: 148, 777: -1, 778: -1, 779: 149, 780: 150, 781: -1, 782: -1, 783: -1, 784: -1, 785: -1, 786: 151, 787: -1, 788: -1, 789: -1, 790: -1, 791: -1, 792: 152, 793: -1, 794: -1, 795: -1, 796: -1, 797: 153, 798: -1, 799: -1, 800: -1, 801: -1, 802: 154, 803: 155, 804: 156, 805: -1, 806: -1, 807: -1, 808: -1, 809: -1, 810: -1, 811: -1, 812: -1, 813: 157, 814: -1, 815: 158, 816: -1, 817: -1, 818: -1, 819: -1, 820: 159, 821: -1, 822: -1, 823: 160, 824: -1, 825: -1, 826: -1, 827: -1, 828: -1, 829: -1, 830: -1, 831: 161, 832: -1, 833: 162, 834: -1, 835: 163, 836: -1, 837: -1, 838: -1, 839: 164, 840: -1, 841: -1, 842: -1, 843: -1, 844: -1, 845: 165, 846: -1, 847: 166, 848: -1, 849: -1, 850: 167, 851: -1, 852: -1, 853: -1, 854: -1, 855: -1, 856: -1, 857: -1, 858: -1, 859: 168, 860: -1, 861: -1, 862: 169, 863: -1, 864: -1, 865: -1, 866: -1, 867: -1, 868: -1, 869: -1, 870: 170, 871: -1, 872: -1, 873: -1, 874: -1, 875: -1, 876: -1, 877: -1, 878: -1, 879: 171, 880: 172, 881: -1, 882: -1, 883: -1, 884: -1, 885: -1, 886: -1, 887: -1, 888: 173, 889: -1, 890: 174, 891: -1, 892: -1, 893: -1, 894: -1, 895: -1, 896: -1, 897: 175, 898: -1, 899: -1, 900: 176, 901: -1, 902: -1, 903: -1, 904: -1, 905: -1, 906: -1, 907: 177, 908: -1, 909: -1, 910: -1, 911: -1, 912: -1, 913: 178, 914: -1, 915: -1, 916: -1, 917: -1, 918: -1, 919: -1, 920: -1, 921: -1, 922: -1, 923: -1, 924: 179, 925: -1, 926: -1, 927: -1, 928: -1, 929: -1, 930: -1, 931: -1, 932: 180, 933: 181, 934: 182, 935: -1, 936: -1, 937: 183, 938: -1, 939: -1, 940: -1, 941: -1, 942: -1, 943: 184, 944: -1, 945: 185, 946: -1, 947: 186, 948: -1, 949: -1, 950: -1, 951: 187, 952: -1, 953: -1, 954: 188, 955: -1, 956: 189, 957: 190, 958: -1, 959: 191, 960: -1, 961: -1, 962: -1, 963: -1, 964: -1, 965: -1, 966: -1, 967: -1, 968: -1, 969: -1, 970: -1, 971: 192, 972: 193, 973: -1, 974: -1, 975: -1, 976: -1, 977: -1, 978: -1, 979: -1, 980: 194, 981: 195, 982: -1, 983: -1, 984: 196, 985: -1, 986: 197, 987: 198, 988: 199, 989: -1, 990: -1, 991: -1, 992: -1, 993: -1, 994: -1, 995: -1, 996: -1, 997: -1, 998: -1, 999: -1}\n",
    "two_hundred_to_1000 = dict(map(reversed, thousand_k_to_200.items()))\n",
    "\n",
    "X_adv, y_adv = zip(*datasets.ImageFolder(root=\"./imagenet-a/\", transform=test_transform))\n",
    "X_adv = torch.stack(X_adv)\n",
    "y_adv = torch.tensor(list(map(two_hundred_to_1000.get, y_adv)))\n",
    "\n",
    "X_test, y_test = map(torch.cat, zip(*val_loader))  # Read the whole test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d907ecd079674f3ebb264d87cca2ee00",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "base_error\t:0.30764\n",
      "drawdown\t:0.0014556\n",
      "mean_complexity\t:2.149\n",
      "success_rate\t:1.0\n"
     ]
    }
   ],
   "source": [
    "from lib import evaluate_quality\n",
    "from tqdm import tqdm_notebook\n",
    "\n",
    "np.random.seed(9)\n",
    "indices = np.random.permutation(len(X_adv))[:1000]\n",
    "X_edit = X_adv[indices].cuda()\n",
    "y_edit = y_adv[indices].cuda()\n",
    "metrics_adv = evaluate_quality(model, X_test, y_test, X_edit, y_edit, \n",
    "                           error_function=classification_error, progressbar=tqdm_notebook)\n",
    "\n",
    "for key in sorted(metrics_adv.keys()):\n",
    "    print('{}\\t:{:.5}'.format(key, metrics_adv[key]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
